{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# SV-DKL with Pyro"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/gpleiss/anaconda3/envs/gpytorch/lib/python3.7/site-packages/matplotlib/__init__.py:999: UserWarning: Duplicate key in file \"/home/gpleiss/.dotfiles/matplotlib/matplotlibrc\", line #57\n",
      "  (fname, cnt))\n"
     ]
    }
   ],
   "source": [
    "import math\n",
    "import torch\n",
    "import gpytorch\n",
    "import pyro\n",
    "from matplotlib import pyplot as plt\n",
    "\n",
    "# Make plots inline\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import urllib.request\n",
    "import os.path\n",
    "from scipy.io import loadmat\n",
    "from math import floor\n",
    "\n",
    "if not os.path.isfile('3droad.mat'):\n",
    "    print('Downloading \\'3droad\\' UCI dataset...')\n",
    "    urllib.request.urlretrieve('https://www.dropbox.com/s/f6ow1i59oqx05pl/3droad.mat?dl=1', '3droad.mat')\n",
    "    \n",
    "data = torch.Tensor(loadmat('3droad.mat')['data'])\n",
    "X = data[:, :-1]\n",
    "X = X - X.min(0)[0]\n",
    "X = 2 * (X / X.max(0)[0]) - 1\n",
    "y = data[:, -1]\n",
    "\n",
    "# Use the first 80% of the data for training, and the last 20% for testing.\n",
    "train_n = int(floor(0.8*len(X)))\n",
    "\n",
    "train_x = X[:train_n, :].contiguous().cuda()\n",
    "train_y = y[:train_n].contiguous().cuda()\n",
    "\n",
    "test_x = X[train_n:, :].contiguous().cuda()\n",
    "test_y = y[train_n:].contiguous().cuda()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.utils.data import TensorDataset, DataLoader\n",
    "train_dataset = TensorDataset(train_x, train_y)\n",
    "train_loader = DataLoader(train_dataset, batch_size=1024, shuffle=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "data_dim = train_x.size(-1)\n",
    "\n",
    "class LargeFeatureExtractor(torch.nn.Sequential):           \n",
    "    def __init__(self):                                      \n",
    "        super(LargeFeatureExtractor, self).__init__()        \n",
    "        self.add_module('linear1', torch.nn.Linear(data_dim, 1000))\n",
    "        self.add_module('bn1', torch.nn.BatchNorm1d(1000))\n",
    "        self.add_module('relu1', torch.nn.ReLU())                  \n",
    "        self.add_module('linear2', torch.nn.Linear(1000, 500))\n",
    "        self.add_module('bn2', torch.nn.BatchNorm1d(500))\n",
    "        self.add_module('relu2', torch.nn.ReLU())                  \n",
    "        self.add_module('linear3', torch.nn.Linear(500, 50))\n",
    "        self.add_module('bn3', torch.nn.BatchNorm1d(50))\n",
    "        self.add_module('relu3', torch.nn.ReLU())                  \n",
    "        self.add_module('linear4', torch.nn.Linear(50, 2))         \n",
    "                                                             \n",
    "feature_extractor = LargeFeatureExtractor().cuda()\n",
    "# num_features is the number of final features extracted by the neural network, in this case 2.\n",
    "num_features = 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "from gpytorch.models import PyroVariationalGP\n",
    "from gpytorch.variational import CholeskyVariationalDistribution, GridInterpolationVariationalStrategy\n",
    "\n",
    "\n",
    "class PyroSVDKLGridInterpModel(PyroVariationalGP):\n",
    "    def __init__(self, likelihood, grid_size=32, grid_bounds=[(-1, 1), (-1, 1)], name_prefix=\"svdkl_grid_example\"):\n",
    "        variational_distribution = CholeskyVariationalDistribution(num_inducing_points=(grid_size ** num_features))\n",
    "        variational_strategy = GridInterpolationVariationalStrategy(self,\n",
    "                                                                    grid_size=grid_size,\n",
    "                                                                    grid_bounds=grid_bounds,\n",
    "                                                                    variational_distribution=variational_distribution)\n",
    "        super(PyroSVDKLGridInterpModel, self).__init__(variational_strategy,\n",
    "                                                       likelihood,\n",
    "                                                       num_data=train_y.numel(),\n",
    "                                                       name_prefix=name_prefix)\n",
    "        self.mean_module = gpytorch.means.ConstantMean()\n",
    "        self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel(\n",
    "            lengthscale_prior=gpytorch.priors.SmoothedBoxPrior(0.001, 1., sigma=0.1, transform=torch.exp\n",
    "        ))\n",
    "\n",
    "    def forward(self, x):\n",
    "        mean_x = self.mean_module(x)\n",
    "        covar_x = self.covar_module(x)\n",
    "        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "class DKLModel(gpytorch.Module):\n",
    "    def __init__(self, likelihood, feature_extractor, num_features, grid_bounds=(-1., 1.)):\n",
    "        super(DKLModel, self).__init__()\n",
    "        self.feature_extractor = feature_extractor\n",
    "        self.gp_layer = PyroSVDKLGridInterpModel(likelihood)\n",
    "        self.grid_bounds = grid_bounds\n",
    "        self.num_features = num_features\n",
    "\n",
    "    def features(self, x):\n",
    "        features = self.feature_extractor(x)\n",
    "        features = gpytorch.utils.grid.scale_to_bounds(features, self.grid_bounds[0], self.grid_bounds[1])\n",
    "        return features\n",
    "    \n",
    "    def forward(self, x):\n",
    "        res = self.gp_layer(self.features(x))\n",
    "        return res\n",
    "    \n",
    "    def guide(self, x, y):\n",
    "        self.gp_layer.guide(self.features(x), y)\n",
    "    \n",
    "    def model(self, x, y):\n",
    "        pyro.module(self.gp_layer.name_prefix + \".feature_extractor\", self.feature_extractor)\n",
    "        self.gp_layer.model(self.features(x), y)\n",
    "\n",
    "\n",
    "likelihood = gpytorch.likelihoods.GaussianLikelihood().cuda()\n",
    "model = DKLModel(likelihood, feature_extractor, num_features=num_features).cuda()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "from pyro import optim\n",
    "from pyro import infer\n",
    "\n",
    "optimizer = optim.Adam({\"lr\": 0.1})\n",
    "\n",
    "elbo = infer.Trace_ELBO(num_particles=256, vectorize_particles=True)\n",
    "svi = infer.SVI(model.model, model.guide, optimizer, elbo)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1   Loss 1361756.0124511719\n",
      "Epoch 2   Loss 1443523.3786621094\n",
      "Epoch 3   Loss 1332343.6667480469\n"
     ]
    }
   ],
   "source": [
    "num_epochs = 3\n",
    "# Not enough for this model to converge, but enough for a fast example\n",
    "\n",
    "for i in range(num_epochs):\n",
    "    # Within each iteration, we will go over each minibatch of data\n",
    "    for minibatch_i, (x_batch, y_batch) in enumerate(train_loader):\n",
    "        loss = svi.step(x_batch, y_batch)\n",
    "    print('Epoch {}   Loss {}'.format(i + 1, loss))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.eval()\n",
    "likelihood.eval()\n",
    "with torch.no_grad():\n",
    "    preds = model(test_x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Test MAE: 10.290990829467773\n"
     ]
    }
   ],
   "source": [
    "print('Test MAE: {}'.format(torch.mean(torch.abs(preds.mean - test_y))))"
   ]
  }
 ],
 "metadata": {
  "anaconda-cloud": {},
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
